Packages for this section
library(tidyverse)
library(latex2exp)
library(jsonlite)
In order to expose systemic disparities, section 6 of the main paper proposes to partition policyholders following a relevant local fairness metric. In this section, we aim to partition policyholders following proxy vulnerability (section 6.1 of the main paper) and following commercial loading (section 6.2 of the main paper).
Furthermore, we perform an experiment on Scenario 1 to see how effective partitioning is at identifying high proxy vulnerability policyholders. We gain insight into best practices for partitioning through this repeated partitioning experiment.
This section is based on simulations of sec-simul-dataset and estimated metrics from sec-local, and it will serve to build the integrated framework of sec-integrated.
#### EVTREE ESSENTIALS
# Function to train and evaluate a model
# Fit an evtree model and two rpart benchmarks on `train_data`, score them on
# `valid_data`, and return the fitted objects with their validation scores,
# leaf counts and timings.
#
# params        : one row of the hyperparameter grid; must contain
#                 `minbucket`, `maxdepth`, `ntrees` and `alpha`.
# train_data    : data.frame with columns `resp`, `X1`, `X2`, `D`.
# valid_data    : data.frame with the same columns, used for evaluation.
# response_name : label for the response, used only in log messages.
#
# Returns a list with components `evtree`, `rpart` (validation-pruned) and
# `rpart_eight` (pruned towards 8 leaves), plus `response_name` and `params`,
# or NULL if any step errors (the error is logged).
evaluate_model <- function(params, train_data, valid_data, response_name) {
tryCatch({
flog.info("Training model for response: %s with params: %s", response_name, toString(params))
beg_evtree <- Sys.time()
# Train evtree model (globally optimal tree)
evtree_model <- evtree(
resp ~ X1 + X2 + D,
data = train_data,
control = evtree.control(
# a node must be able to hold two children of size `minbucket` to split
minsplit = params$minbucket * 2 + 1,
minbucket = params$minbucket,
maxdepth = params$maxdepth,
ntrees = params$ntrees,
alpha = params$alpha
))
time_evtree <- as.numeric(difftime(Sys.time(), beg_evtree, units = "secs"))
beg_rpart <- Sys.time()
# Train rpart (deep): cp is kept very small so the tree is grown deep
# and pruned afterwards.
rpart_model <- rpart(
resp ~ X1 + X2 + D,
data = train_data,
method = 'anova',
control = rpart.control(
minsplit = params$minbucket * 2 + 1,
minbucket = params$minbucket,
maxdepth = params$maxdepth,
cp = 0.00001
)
)
time_rpart <- as.numeric(difftime(Sys.time(), beg_rpart, units = "secs"))
# Number of evtree leaves = distinct terminal nodes on the training data
evtree_leaves <- length(unique(predict(evtree_model, type = "node")))
rpart_cptable <- rpart_model$cptable
# NOTE(review): the original comment claimed "largest CP for nsplit <=
# evtree_leaves - 1", but the code selects the CP minimizing the
# cross-validated error (xerror) instead — confirm which pruning rule
# is intended.
min_xerror_index <- which.min(rpart_cptable[, "xerror"])
matching_cp <- rpart_cptable[min_xerror_index, "CP"]
pruned_rpart_model <- prune(rpart_model, cp = matching_cp)
# Prune a second copy of the rpart model towards 8 leaves (7 splits),
# taking the CP row whose nsplit is closest to 7.
target_8_splits <- 8 - 1
matching_8_cp <- rpart_cptable[which.min(abs(rpart_cptable[, "nsplit"] - target_8_splits)), 'CP']
pruned_8_rpart_model <- prune(rpart_model, cp = matching_8_cp)
# Evaluate evtree on validation set. Despite the `_mse` names, the score
# is a BIC-style criterion under a Gaussian assumption:
#   n * log(SSE / n) + (#distinct predictions) * log(n).
n <- nrow(valid_data)
evtree_preds <- predict(evtree_model, newdata = valid_data)
evtree_mse <- n * log(sum((valid_data$resp - evtree_preds)^2, na.rm = TRUE)/ n) + length(unique(evtree_preds)) * log(n)
flog.info("Validation MSE for evtree (response: %s): %f", response_name, evtree_mse)
# Evaluate the validation-pruned rpart on the validation set
rpart_preds <- predict(pruned_rpart_model, newdata = valid_data, type = "vector")
rpart_mse <- n * log(sum((valid_data$resp - rpart_preds)^2, na.rm = TRUE)/ n) + length(unique(rpart_preds)) * log(n)
flog.info("Validation MSE for rpart (response: %s): %f", response_name, rpart_mse)
# Evaluate the 8-leaf pruned rpart on the validation set
rpart_8preds <- predict(pruned_8_rpart_model, newdata = valid_data, type = "vector")
rpart_8mse <- n * log(sum((valid_data$resp - rpart_8preds)^2, na.rm = TRUE)/ n) + length(unique(rpart_8preds)) * log(n)
flog.info("Validation MSE for 8rpart (response: %s): %f", response_name, rpart_8mse)
# Return results for all three models. Leaf counts for the rpart models
# are approximated by the number of distinct (rounded) validation
# predictions, which may undercount leaves with equal means.
list(
evtree = list(model = evtree_model, val_mse = evtree_mse,
leaves = evtree_leaves,
evtree_val_preds = evtree_preds,
time = time_evtree),
rpart = list(model = pruned_rpart_model,
val_mse = rpart_mse,
leaves = length(unique(round(rpart_preds, 3))),
time = time_rpart,
rpart_val_preds = rpart_preds,
cp = matching_cp),
# NOTE(review): `time_rpart` is reused here; the 8-leaf model shares the
# same growing step and its pruning time is not measured separately.
rpart_eight = list(model = pruned_8_rpart_model,
val_mse = rpart_8mse,
leaves = length(unique(round(rpart_8preds, 3))),
time = time_rpart,
rpart_val_preds = rpart_8preds,
cp = matching_8_cp),
response_name = response_name,
params = params
)
}, error = function(e) {
flog.error("Error while training model: %s", e$message)
NULL
})
}
# Collapse an ordered set of factor levels into "bottom / middle / top":
# the first `n_bottom` and last `n_top` levels keep their own labels while
# everything in between is merged into a single middle level.
#
# levels        : numeric (or coercible) vector of level values, in order.
# numeric       : if TRUE, label the middle group with the rounded mean of the
#                 matching `factor_values`; otherwise use a "min-max" range.
# n_bottom      : number of leading levels kept as-is.
# n_top         : number of trailing levels kept as-is.
# factor_values : observed values, required when `numeric = TRUE`, used to
#                 compute the middle group's mean label.
#
# Returns an ordered factor over the recoded labels, preserving the original
# level order. When there are too few levels to form a middle group, the
# input levels are returned unchanged (as an unordered factor, matching the
# original behaviour).
recode_bottom_top_middle <- function(levels, numeric = FALSE, n_bottom = 3, n_top = 3, factor_values = NULL) {
  # Validate input
  if (is.null(levels) || length(levels) < 1) {
    stop("Input levels must be a non-empty numeric vector.")
  }
  # Ensure levels are numeric
  levels <- as.numeric(levels)
  # If numeric = TRUE, factor_values must be provided
  if (numeric && (is.null(factor_values) || length(factor_values) < 1)) {
    stop("For numeric = TRUE, a valid 'factor_values' vector must be provided.")
  }
  # Not enough levels to carve out a middle group: return levels unchanged
  if (length(levels) <= (n_bottom + n_top + 1)) {
    return(factor(levels, levels = as.character(levels)))
  }
  # Bottom/top are taken by position; middle is whatever value remains.
  # seq_len() (rather than 1:n_bottom) behaves correctly when n_bottom = 0.
  bottom <- levels[seq_len(n_bottom)]
  top <- levels[(length(levels) - n_top + 1):length(levels)]
  middle <- levels[!(levels %in% c(bottom, top))]
  # Build the label for the merged middle group
  if (length(middle) > 1) {
    if (numeric) {
      # Mean of the observed values that fall inside the middle group
      middle_values <- factor_values[factor_values %in% middle]
      middle_label <- round(mean(as.numeric(as.character(middle_values)), na.rm = TRUE), 3)
    } else {
      middle_label <- paste0(min(middle), "-", max(middle)) # Range as label
    }
  } else if (length(middle) == 1) {
    middle_label <- as.character(middle) # Single value keeps its own label
  } else {
    middle_label <- NULL
  }
  # Map every level to its new label (vapply: type-stable, unlike sapply)
  recoded <- vapply(levels, function(x) {
    if (x %in% middle) {
      as.character(middle_label) # Group middle levels under the shared label
    } else {
      as.character(x) # Bottom and top levels keep their own label
    }
  }, character(1))
  # Ordered factor whose level order follows the original input order
  factor(recoded, levels = unique(recoded), ordered = TRUE)
}
# Function to extract paths for all terminal nodes from any tree
# Extract the decision path (sequence of split conditions) leading to every
# terminal node of a partykit tree.
#
# tree : a `party` object (an evtree result or `as.party(rpart)` output).
#
# Returns a list with the raw paths (`full`), their "cond AND cond"
# concatenations (`full_concat`), and simplified versions (`simp`,
# `simp_concat`) produced by `simplify_path()` (defined elsewhere in this
# file).
extract_paths_for_all_terminals <- function(tree) {
  # The tree is assumed to already be a party object
  party_tree <- tree
  # Get terminal node IDs
  terminal_nodes <- nodeids(party_tree, terminal = TRUE)
  extract_path_for_terminal <- function(party_tree, terminal_node_index) {
    # Root of the recursive partynode structure
    root_node <- party_tree$node
    # Split varids index into the tree's data columns
    var_names <- names(party_tree$data)
    terminal_nodes <- partykit::nodeids(party_tree, terminal = TRUE)
    # Ensure the requested terminal node index is valid
    if (terminal_node_index > length(terminal_nodes) || terminal_node_index < 1) {
      stop("Terminal node index exceeds the number of terminal nodes.")
    }
    # Node ID corresponding to the requested terminal node index
    terminal_node_id <- terminal_nodes[terminal_node_index]
    # Infer which branch carries the "<" condition by inspecting the node's
    # printed representation (partykit does not expose the operator directly).
    extract_operator_from_second_line <- function(node) {
      node_text <- capture.output(print(node))
      if (length(node_text) < 2) {
        stop("The printed output does not contain enough lines to extract the operator.")
      }
      second_line <- node_text[2]
      if (grepl("<", second_line)) {
        return("<")
      } else if (grepl(">=", second_line)) {
        return(">=")
      } else {
        stop("No valid operator (< or >=) found in the second line.")
      }
    }
    # Recursively walk from `node` towards the target terminal, collecting
    # the split condition taken at each internal node.
    traverse_tree <- function(node, path = NULL) {
      # Reached the designated terminal: the accumulated path is complete
      if (node$id == terminal_node_id) {
        return(path)
      }
      # A different terminal node: dead end
      if (is.null(node$kids)) {
        return(NULL)
      }
      split_variable <- var_names[node$split$varid]
      split_value <- node$split$breaks
      # BUG FIX: `print(node)` was previously passed here, which printed the
      # node to the console as a side effect on every visit; pass the node
      # itself (the returned operator is identical).
      left_child_is_kid1 <- grepl('<', extract_operator_from_second_line(node))
      if (left_child_is_kid1) {
        # First kid holds the "< split_value" branch
        if (terminal_node_id %in% nodeids(node$kids[[1]])) {
          return(traverse_tree(node$kids[[1]], c(path, paste0(split_variable, " < ", split_value))))
        }
        if (terminal_node_id %in% nodeids(node$kids[[2]])) {
          return(traverse_tree(node$kids[[2]], c(path, paste0(split_variable, " >= ", split_value))))
        }
      } else {
        # First kid holds the ">= split_value" branch
        if (terminal_node_id %in% nodeids(node$kids[[1]])) {
          return(traverse_tree(node$kids[[1]], c(path, paste0(split_variable, " >= ", split_value))))
        }
        if (terminal_node_id %in% nodeids(node$kids[[2]])) {
          return(traverse_tree(node$kids[[2]], c(path, paste0(split_variable, " < ", split_value))))
        }
      }
      NULL # Target not found under this node
    }
    # Start traversing from the root node; returns the path as a vector of
    # conditions (or NULL)
    traverse_tree(root_node)
  }
  # Extract the path for every terminal node
  paths <- lapply(seq_along(terminal_nodes), function(index) {
    extract_path_for_terminal(party_tree, index)
  })
  # Key the paths by their terminal node IDs
  names(paths) <- terminal_nodes
  # Simplify each path (merge redundant conditions on the same variable)
  simplified_paths <- setNames(nm = names(paths)) %>% lapply(function(name_p) {
    simplify_path(path_conditions = paths[[name_p]], round_digits = 2)
  })
  list('full' = paths,
       'full_concat' = lapply(paths, function(p) paste(p, collapse = ' AND ')),
       'simp' = simplified_paths,
       'simp_concat' = lapply(simplified_paths, function(p) paste(p, collapse = ' AND '))
  )
}
# Simplify a path of split conditions by keeping, for each variable, only the
# most restrictive lower and upper bounds, optionally rounding the thresholds.
#
# path_conditions : character vector of conditions such as "X1 < 2.5" or
#                   "X2 >= 1" (the format produced by
#                   extract_paths_for_all_terminals).
# round_digits    : optional number of digits to round thresholds to.
#
# Returns a named character vector (one entry per variable) of the form
# "a < var < b", "var < b" or "var >= a". Errors if the retained bounds for a
# variable are contradictory.
simplify_path <- function(path_conditions, round_digits = NULL) {
# Parse conditions into variable, operator, and value
parsed_conditions <- lapply(path_conditions, function(cond) {
matches <- regmatches(cond, regexec("^(\\w+)\\s*([<>]=?)\\s*(.*)$", cond))[[1]]
list(var = matches[2], op = matches[3], val = as.numeric(matches[4]))
})
# Group conditions by variable
grouped_conditions <- split(parsed_conditions, sapply(parsed_conditions, function(x) x$var))
# Simplify conditions for each variable
simplified_conditions <- lapply(grouped_conditions, function(conds) {
# Separate "<" and ">" conditions
less_than <- conds[sapply(conds, function(x) x$op %in% c("<", "<="))]
greater_than <- conds[sapply(conds, function(x) x$op %in% c(">", ">="))]
# Most restrictive "<" condition = smallest upper bound
# (despite its name, `max_less_than` holds the minimum threshold)
if (length(less_than) > 0) {
max_less_than <- less_than[[which.min(sapply(less_than, function(x) x$val))]]
} else {
max_less_than <- NULL
}
# Most restrictive ">" condition = largest lower bound
if (length(greater_than) > 0) {
max_greater_than <- greater_than[[which.max(sapply(greater_than, function(x) x$val))]]
} else {
max_greater_than <- NULL
}
# Apply rounding if specified
if (!is.null(round_digits)) {
if (!is.null(max_less_than)) max_less_than$val <- round(max_less_than$val, round_digits)
if (!is.null(max_greater_than)) max_greater_than$val <- round(max_greater_than$val, round_digits)
}
# Recombine conditions into a valid range or a single inequality
if (!is.null(max_less_than) && !is.null(max_greater_than)) {
if (max_greater_than$val >= max_less_than$val) {
stop("Conflicting conditions for variable ", max_less_than$var)
}
paste0(max_greater_than$val, " < ", max_greater_than$var, " < ", max_less_than$val)
} else if (!is.null(max_less_than)) {
paste0(max_less_than$var, " ", max_less_than$op, " ", max_less_than$val)
} else if (!is.null(max_greater_than)) {
paste0(max_greater_than$var, " ", max_greater_than$op, " ", max_greater_than$val)
}
})
# Flatten into a named character vector (names come from the grouping)
return(unlist(simplified_conditions))
}We apply an optimal partitioning algorithm, evtree from Grubinger, Zeileis, and Pfeiffer (2014), to policyholders based on proxy vulnerability in the three scenarios of the example. We use \((X_1, X_2)\) as the feature space for partitioning and impose strong regularization to limit the number of groups.
Figure fig-ex_partitioning_clusters presents the results for the three scenarios. The top row shows estimated proxy vulnerability, with colors indicating the groups resulting from the optimal partition of proxy vulnerability. While the left panel may not match intuition in terms of the number of groups in scenario 1, the predicted values of proxy vulnerability based on the evtree align with expectations: darker red indicates individuals most vulnerable to proxy effects. The bottom row of Figure fig-ex_partitioning_clusters depicts the partition in the \((x_1, x_2)\) domain. The structure aligns with the example design: high proxy vulnerability for individuals with \(x_2 = 4\) and large \(x_1\), and important variation in proxy vulnerability across \(x_2\).
# Load the helper that trains (or loads cached) evtree models per scenario
source("___train_evtree_scenario.R")
# Subsample each training set to 10% to keep evtree training tractable;
# the other sets are kept whole (fraction 1).
pregroup_pop_stats_small <- setNames(nm = names(pregroup_pop_stats)) %>% lapply(function(pop_name){
setNames(nm = names(pregroup_pop_stats[[pop_name]])) %>% lapply(function(the_set){
the_frac <- ifelse(the_set == 'train', 0.1 , 1)
pregroup_pop_stats[[pop_name]][[the_set]] %>%
sample_frac(the_frac)
})
})
# Define hyperparameter grid: minimum leaf size expressed as a share of the
# (reduced) training set, tree depth, and the evtree complexity penalty alpha.
param_grid <- expand.grid(
minbucket = c(0.03, 0.05) * nrow(pregroup_pop_stats_small$Scenario1$train),
maxdepth = c(3, 4),
alpha = c(1, 2),
ntrees = 25,
stringsAsFactors = FALSE
)
output_dir <- "evtree" # Directory to save models
response_vars <- c("proxy_vuln", 'comm_load') # List of response variables
# Call process_populations with actual inputs.
# NOTE(review): the trailing "INFO [...]" text on the closing line is captured
# log output fused onto the code line by the document extraction, not code.
my_trees <- process_populations(preds_pop_stats = pregroup_pop_stats_small,
response_vars = response_vars,
param_grid = param_grid,
output_dir)INFO [2025-10-06 14:35:09] Processing response variable: proxy_vuln
INFO [2025-10-06 14:35:09] Processing population: Scenario1 for response: proxy_vuln
INFO [2025-10-06 14:35:09] Model for population Scenario1 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing population: Scenario2 for response: proxy_vuln
INFO [2025-10-06 14:35:10] Model for population Scenario2 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing population: Scenario3 for response: proxy_vuln
INFO [2025-10-06 14:35:10] Model for population Scenario3 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing response variable: comm_load
INFO [2025-10-06 14:35:10] Processing population: Scenario1 for response: comm_load
INFO [2025-10-06 14:35:10] Model for population Scenario1 and response comm_load already exists. Loading...
INFO [2025-10-06 14:35:12] Processing population: Scenario2 for response: comm_load
INFO [2025-10-06 14:35:12] Model for population Scenario2 and response comm_load already exists. Loading...
INFO [2025-10-06 14:35:12] Processing population: Scenario3 for response: comm_load
INFO [2025-10-06 14:35:12] Model for population Scenario3 and response comm_load already exists. Loading...
library(rpart)
# BUG FIX: `library(ggparty, partykit)` passed partykit as the `help`
# argument of library() and never attached it; attach each package with its
# own call.
library(ggparty)
library(partykit)
# For each algorithm (evtree, pruned rpart) and each scenario, draw the fitted
# proxy-vulnerability partition tree with ggparty, stack the three scenarios
# vertically, and save the figure to figs/.
temp_tree <- c('evtree', 'rpart') %>% lapply(function(the_algo){
names(pregroup_grid_stats) %>% lapply(function(pop_name){
pop_id <- which(names(pregroup_grid_stats) == pop_name)
# Compute sequential terminal node IDs (rpart models are converted to
# party objects first so that partykit accessors apply)
party_tree <- my_trees$proxy_vuln[[pop_name]][[paste0('best_', the_algo)]]$model
if (the_algo == 'rpart'){
party_tree <- partykit::as.party(party_tree)
}
terminal_ids <- nodeids(party_tree, terminal = TRUE) # Original terminal node IDs
sequential_ids <- seq_along(terminal_ids) # Create sequential IDs
id_mapping <- data.frame(terminal_id = terminal_ids, sequential_id = sequential_ids)
## Compute average prediction per terminal node
# Extract predictions and terminal node IDs
predictions <- fitted(party_tree)
avg_prediction <- aggregate(`(response)` ~ `(fitted)`,
data = predictions,
FUN = mean)
# Tree plot: inner nodes show split variable and size; terminal nodes show
# the sequential node number, size and average prediction, plus a boxplot
# of the response per leaf coloured by its median.
tree_plot <- ggparty(party_tree) +
geom_edge() +
geom_edge_label(mapping = aes(label = !!sym("breaks_label")),
size = 3) +
geom_node_label(
line_list = list(
aes(label = splitvar),
aes(label = paste("N =", nodesize))
),
line_gpar = list(
list(size = 10),
list(size = 8)
),
ids = "inner",
) +
geom_node_label(
line_list = list(
aes(label = paste0("Node ",
match(id, id_mapping$terminal_id),
", N = ",
nodesize)),
aes(label = paste0("Avg Pred. = ",
round(avg_prediction$`(response)`[match(id, avg_prediction$`(fitted)`)], 2)))
),
line_gpar = list(
list(size = 8),
list(size = 10)
),
ids = "terminal", nudge_y = -0.45, nudge_x = 0.01,
label.size = 0.15,
size = 3
) +
geom_node_plot(
gglist = list(
# NOTE(review): `..middle..` is the legacy dot-dot notation for
# after_stat(middle); it still works but is deprecated in recent
# ggplot2 — confirm against the project's ggplot2 version.
geom_boxplot(aes(x = "", y = resp,
color = ..middle..,
fill = ..middle..), # Color by median
outlier.color = "black"
, alpha = 0.7
),
theme_minimal(),
scale_fill_gradient2(
low = "#D7CC39", mid = "grey75", high = "#CAA8F5",
midpoint = 0, name = "Median Value"
),
scale_color_gradient2(
low = colorspace::darken("#D7CC39", 0.3), mid = colorspace::darken("grey75", 0.3),
high = colorspace::darken("#CAA8F5", 0.3),
midpoint = 0, name = "Median Value"
),
xlab(""), ylab(latex2exp::TeX("$\\widehat{\\Delta}_{proxy}(X_1, X_2)$")),
scale_y_continuous(labels = scales::dollar),
theme(axis.text.x = element_blank(),
axis.title.y = element_text(margin = margin(l = -10)),
axis.title.x = element_text(margin = margin(r = 20)))
),
shared_axis_labels = TRUE
) +
ggtitle(latex2exp::TeX(paste0('Partition of proxy vulnerable individuals for scenario ', pop_id))) +
theme(
plot.title = element_text(size = 16, face = "bold", hjust = 0.5)
)
# Arrange the three scenario plots vertically and write the figure to disk
}) %>% ggpubr::ggarrange(plotlist = .,
nrow = 3,
widths = 15, heights = 1,
common.legend = T,
legend = 'right') %>%
ggsave(filename = paste0("figs/graph_trees_", the_algo,"_proxy.png"),
plot = .,
height = 16,
width = 12,
units = "in",
device = "png", dpi = 500)
})
# Build, for every response and scenario, a dictionary mapping each terminal
# node of the fitted trees (evtree, validation-pruned rpart, 8-leaf rpart) to
# a rank (`node_new`, 1 = highest prediction), together with the decision
# paths of every leaf.
# NOTE(review): `rm(temp_tree)` and the assignment below were fused onto one
# line by the document extraction.
rm(temp_tree)dictionnary_leaves_trees <- setNames(nm = names(my_trees)) %>% lapply(function(resp_tree){
setNames(nm = names(my_trees[[resp_tree]])) %>% lapply(function(pop_name){
temp_to_pred <- pregroup_grid_stats[[pop_name]]
# Trees were trained on upper-case feature names (X1, X2, D)
names(temp_to_pred) <- toupper(names(temp_to_pred))
model_ev <- my_trees[[resp_tree]][[pop_name]]$best_evtree$model
model_rpart <- my_trees[[resp_tree]][[pop_name]]$best_rpart$model
# 8-leaf version: select the CP row with exactly 7 splits, then convert to
# a party object for path extraction
model_rpart_prune <- prune(model_rpart,
cp = model_rpart$cptable[which(model_rpart$cptable[, "nsplit"] + 1 == 8), "CP"]) %>% as.party()
model_rpart <- model_rpart %>% as.party()
# One row per terminal node: original node id, rounded prediction, and a
# sequential id ordered by decreasing prediction
to_return_evtree <- data.frame('node_or' = predict(model_ev, newdata = temp_to_pred,
type = 'node') %>% unname,
'pred' = predict(model_ev, newdata = temp_to_pred,
type = 'response') %>% unname %>%
round(3)) %>% distinct() %>% arrange(-pred) %>%
mutate('node_new' = 1:n())
to_return_rpart <- data.frame('node_or' = predict(model_rpart, newdata = temp_to_pred,
type = 'node') %>% unname,
'pred' = predict(model_rpart, newdata = temp_to_pred, type = 'response') %>% unname %>%
round(3)) %>% distinct() %>% arrange(-pred) %>%
mutate('node_new' = 1:n())
to_return_rpart_prune <- data.frame('node_or' = predict(model_rpart_prune, newdata = temp_to_pred,
type = 'node') %>% unname,
'pred' = predict(model_rpart_prune, newdata = temp_to_pred,
type = 'response') %>% unname %>%
round(3)) %>% distinct() %>% arrange(-pred) %>%
mutate('node_new' = 1:n())
# Decision paths for every terminal node of each tree
paths_ev <- extract_paths_for_all_terminals(tree = model_ev)
paths_rpart <- extract_paths_for_all_terminals(tree = model_rpart)
paths_rpart_prune <- extract_paths_for_all_terminals(tree = model_rpart_prune)
list('evtree' = list('dict' = to_return_evtree,
'model' = model_ev,
'paths' = paths_ev),
'rpart' = list('dict' = to_return_rpart,
'model' = model_rpart,
'paths' = paths_rpart),
'rpart_prune' = list('dict' = to_return_rpart_prune,
'model' = model_rpart_prune,
'paths' = paths_rpart_prune))
})
})
# Persist the dictionary (the "dictionnary" spelling is kept as-is because the
# name is reused by later code)
saveRDS(dictionnary_leaves_trees, 'evtree/dictionnary_leaves_trees.rds')
### Applying partition to the data
# Cached locations for the grouped grid/population statistics
group_grid_path = 'preds/group_grid_stats.json'
group_pop_path = 'preds/group_pop_stats.json'
# Check and load or compute group_grid_stats. Group columns hold the rounded
# tree predictions, stored as factors whose levels are sorted in decreasing
# order (largest predicted vulnerability/loading first).
if (file.exists(group_grid_path)) {
temp_grid_stats <- fromJSON(group_grid_path)
group_grid_stats <- setNames(nm = names(temp_grid_stats)) |> lapply(function(pop_name){
temp_grid_stats[[pop_name]] |>
mutate(proxy_g_evtree = proxy_g_evtree %>% factor(., levels = sort(unique(as.numeric(proxy_g_evtree)), decreasing = T)),
proxy_g_rpart = proxy_g_rpart %>% factor(., levels = sort(unique(as.numeric(proxy_g_rpart)), decreasing = T)),
cload_g_evtree = cload_g_evtree %>% factor(., levels = sort(unique(as.numeric(cload_g_evtree)), decreasing = T)),
cload_g_rpart = cload_g_rpart %>% factor(., levels = sort(unique(as.numeric(cload_g_rpart)), decreasing = T)))
})
rm(temp_grid_stats)
} else {
group_grid_stats <- setNames(nm = names(pregroup_grid_stats)) %>% lapply(function(pop_name){
temp_to_pred <- pregroup_grid_stats[[pop_name]]
# Trees expect upper-case feature names
names(temp_to_pred) <- toupper(names(temp_to_pred))
pred_proxy_g_evtree <- predict(my_trees$proxy_vuln[[pop_name]]$best_evtree$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_proxy_g_rpart <- predict(my_trees$proxy_vuln[[pop_name]]$best_rpart$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_cload_g_evtree <- predict(my_trees$comm_load[[pop_name]]$best_evtree$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_cload_g_rpart <- predict(my_trees$comm_load[[pop_name]]$best_rpart$model,
newdata = temp_to_pred) %>% round(3) %>% unname
data.frame(pregroup_grid_stats[[pop_name]],
proxy_g_evtree = pred_proxy_g_evtree %>% factor(., levels = sort(unique(pred_proxy_g_evtree), decreasing = T)),
proxy_g_rpart = pred_proxy_g_rpart %>% factor(., levels = sort(unique(pred_proxy_g_rpart), decreasing = T)),
cload_g_evtree = pred_cload_g_evtree %>% factor(., levels = sort(unique(pred_cload_g_evtree), decreasing = T)),
cload_g_rpart = pred_cload_g_rpart %>% factor(., levels = sort(unique(pred_cload_g_rpart), decreasing = T))
)
})
# Cache the computed groups for future runs
toJSON(group_grid_stats, pretty = TRUE, auto_unbox = TRUE) %>%
write(group_grid_path)
}
# Check and load or compute group_pop_stats (same logic as above, applied to
# each train/valid/test set of every scenario).
if (file.exists(group_pop_path)) {
temp_pop_stats <- fromJSON(group_pop_path)
group_pop_stats <- setNames(nm = names(temp_pop_stats)) |> lapply(function(pop_name){
setNames(nm = names(temp_pop_stats[[pop_name]])) |> lapply(function(the_set){
temp_pop_stats[[pop_name]][[the_set]] |>
mutate(proxy_g_evtree = proxy_g_evtree %>% factor(., levels = sort(unique(as.numeric(proxy_g_evtree)), decreasing = T)),
proxy_g_rpart = proxy_g_rpart %>% factor(., levels = sort(unique(as.numeric(proxy_g_rpart)), decreasing = T)),
cload_g_evtree = cload_g_evtree %>% factor(., levels = sort(unique(as.numeric(cload_g_evtree)), decreasing = T)),
cload_g_rpart = cload_g_rpart %>% factor(., levels = sort(unique(as.numeric(cload_g_rpart)), decreasing = T)))
})
})
rm(temp_pop_stats)
} else {
group_pop_stats <- setNames(nm = names(pregroup_pop_stats)) %>% lapply(function(pop_name){
setNames(nm = names(pregroup_pop_stats[[pop_name]])) %>% lapply(function(set){
temp_to_pred <- pregroup_pop_stats[[pop_name]][[set]]
pred_proxy_g_evtree <- predict(my_trees$proxy_vuln[[pop_name]]$best_evtree$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_proxy_g_rpart <- predict(my_trees$proxy_vuln[[pop_name]]$best_rpart$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_cload_g_evtree <- predict(my_trees$comm_load[[pop_name]]$best_evtree$model,
newdata = temp_to_pred) %>% round(3) %>% unname
pred_cload_g_rpart <- predict(my_trees$comm_load[[pop_name]]$best_rpart$model,
newdata = temp_to_pred) %>% round(3) %>% unname
data.frame(pregroup_pop_stats[[pop_name]][[set]],
proxy_g_evtree = pred_proxy_g_evtree %>% factor(., levels = sort(unique(pred_proxy_g_evtree), decreasing = T)),
proxy_g_rpart = pred_proxy_g_rpart %>% factor(., levels = sort(unique(pred_proxy_g_rpart), decreasing = T)),
cload_g_evtree = pred_cload_g_evtree %>% factor(., levels = sort(unique(pred_cload_g_evtree), decreasing = T)),
cload_g_rpart = pred_cload_g_rpart %>% factor(., levels = sort(unique(pred_cload_g_rpart), decreasing = T))
)
})
})
# Cache the computed groups for future runs
toJSON(group_pop_stats, pretty = TRUE, auto_unbox = TRUE) %>%
write(group_pop_path)
}n_bottom <- 5
n_top <- 5
# For each scenario, plot (top) the estimated proxy-vulnerability curves on
# the grid coloured by the evtree group, and (bottom) the validation
# population in the (x1, x2) plane with the same colouring. The n_bottom
# lowest and n_top highest groups keep their own colour; the remaining middle
# groups are merged via recode_bottom_top_middle(). Panels for the three
# scenarios are arranged side by side and saved to figs/.
setNames(nm = names(group_pop_stats)) %>% lapply(function(pop_name){
## the colors
pop_id <- which(names(group_pop_stats) == pop_name)
local_to_g <- group_grid_stats[[pop_name]] %>%
filter(x1 <= 8, x1 >= -5, d == 1)
# Only the first (left-most) scenario keeps axis labels; the other panels
# share the same scales without labels.
if(pop_name == head(names(group_grid_stats), 1)){
the_y_scale_top <- scale_y_continuous(labels = scales::dollar, breaks = c(-5, 0, 5, 10), limits = c(-6, 14))
the_y_label_top <- latex2exp::TeX("$\\Delta_{proxy}(x_1, x_2)$")
the_y_scale <- scale_y_discrete()
the_y_label <- latex2exp::TeX('$x_2$')
} else {
the_y_scale_top <- scale_y_continuous(labels = NULL, breaks = c(-5, 0, 5, 10), limits = c(-6, 14))
the_y_label_top <- NULL
the_y_scale <-scale_y_discrete(labels = NULL)
the_y_label <- NULL
}
local_pop_g <- group_pop_stats[[pop_name]]$valid
# Recode the evtree group factor on both the grid and the population data
local_to_g$proxy_g_evtree_g <- local_to_g$proxy_g_evtree
local_pop_g$proxy_g_evtree_g <- local_pop_g$proxy_g_evtree
levels(local_to_g$proxy_g_evtree_g) <- recode_bottom_top_middle(levels = levels(local_to_g$proxy_g_evtree),
numeric = TRUE,
n_bottom = n_bottom,
n_top = n_top,
factor_values = local_to_g$proxy_g_evtree)
levels(local_pop_g$proxy_g_evtree_g) <- recode_bottom_top_middle(levels = levels(local_pop_g$proxy_g_evtree),
numeric = TRUE,
n_bottom = n_bottom,
n_top = n_top,
factor_values = local_pop_g$proxy_g_evtree)
# Align the merged middle-group label (position n_top + 1) between the grid
# and population factors so both panels share one legend.
levels(local_to_g$proxy_g_evtree_g)[n_top + 1] <- levels(local_pop_g$proxy_g_evtree_g)[n_top + 1]
# Top panel: estimated proxy vulnerability (coloured) against the
# theoretical curves (black lines, one line type per x2)
g_proxy <- local_to_g %>%
mutate(code = paste0(x2, '_', as.numeric(proxy_g_evtree_g))) %>%
ggplot(aes(x = x1, y = proxy_vuln,
group = factor(code),
color = factor(proxy_g_evtree_g))) +
geom_line(aes(x = x1, y = proxy_vuln_t,
lty = factor(x2), group = factor(x2)),
color = 'black', size = 0.8) +
geom_line(size = 3, alpha = 0.78, lineend = "round", linejoin = "round") +
theme_classic() +
labs(x = latex2exp::TeX('$x_1$'),
y = the_y_label_top,
title = paste0('Scenario ', pop_id)) +
scale_color_manual(values = RColorBrewer::brewer.pal(n_bottom + n_top + 1, 'Spectral') %>% colorspace::darken(0.05),
name = latex2exp::TeX('$\\widehat{\\Delta}^{ev}_{proxy}(\\textbf{x})$'),
labels = levels(local_to_g$proxy_g_evtree_g) %>% as.numeric %>% round(2)) +
scale_linetype_manual(values = c('12', '21', '32', 'solid'), name = latex2exp::TeX('$x_2$')) +
the_y_scale_top +
# geom_abline(slope = 0, intercept = 0, lty = '34', color= 'black', size= 0.7, alpha = 0.2) +
scale_x_continuous(labels = NULL, breaks = c(-3:3)*3 + 1) + # see above
guides(
linetype = guide_legend(order = 1), # x2 legend on top
color = guide_legend(order = 2) # k legend below x2
)
# Bottom panel: validation population in the (x1, x2) plane, jittered
# vertically, coloured by group
g_population <- local_pop_g %>%
ggplot(aes(y = factor(X2), x = X1,
color = factor(proxy_g_evtree_g),
fill = factor(proxy_g_evtree_g))) +
geom_jitter(#position = position_identity(),
width = 0, height = 0.4, alpha = 0.2) +
scale_color_manual(values = RColorBrewer::brewer.pal(n_bottom + n_top + 1, 'Spectral') %>% colorspace::darken(0.05),
name = latex2exp::TeX('$\\widehat{\\Delta}^{ev}_{proxy}(\\textbf{x})$'),
labels = levels(local_to_g$proxy_g_evtree) %>% as.numeric %>% round(1)) +
scale_fill_brewer(palette = 'Spectral', name = latex2exp::TeX('$k$')) +
theme_classic() +
the_y_scale +
scale_x_continuous(breaks = c(-3:3)*3 + 1, limits = c(-5, 8)) +
labs(x = latex2exp::TeX("$x_1$") ,
y = the_y_label) +
theme( axis.title.y = element_text(
margin = margin(t = 50), # Add padding
))
# Stack the two panels for this scenario with a shared legend
ggpubr::ggarrange(g_proxy, g_population,
nrow = 2, common.legend = T,
legend = 'right',
heights = c(4, 3),
align = "v")
}) %>%
ggpubr::ggarrange(plotlist = .,
ncol = 3,
widths = c(6, 5, 5)) %>%
ggsave(filename = "figs/graph_proxy_clusters_and_pop_scenario.png",
plot = .,
height = 6.25,
width = 11.50,
units = "in",
device = "png", dpi = 500)To support the partitioning methodology proposed in section 6 of the main paper, we present a simulation study to identify best practices and gain insights into optimal implementation.
Following the set of equations of Scenario 1 in sec-simul-dataset, we simulate \(M = 100\) sets of \(N = 3000\) samples split into \((N_{\text{train}}, N_{\text{valid}}, N_{\text{test}}) = (2000, 500, 500)\) for train, validation, and test. Our aim is to assess the capacity of the methodology to recover distinct proxy-vulnerable subpopulations, precisely identify the most at-risk groups, and predict accurately the proxy vulnerability. We compute the BIC (under Gaussian assumption) of the estimated proxy vulnerability as compared with the test theoretical proxy vulnerability, and the accuracy in the partitioning as compared with the true proxy vulnerable groups: the eight subpopulations formed by the crossing of \(\{X_1 \leq 1, X_1 > 1\}\) and \(X_2 \in \{1, 2, 3, 4\}\).
On each sample set, we estimated unaware and aware premiums using the methodology described in sec-training and we computed proxy vulnerability \(\widehat{\Delta}_{\text{proxy}}(x_1, x_2)\) via Eq. 1 of the main paper. We partitioned the feature space using \((X_1, X_2, D)\). Models used rpart (locally optimal) and evtree (globally optimal) regression trees, with hyperparameters tuned via validation BIC (Gaussian assumption): minimum leaf size proportion \(w \in \{0.03, 0.05\}\), tree depth \(d \in \{3, 4\}\), and complexity parameter \(\alpha \in \{1, 2\}\). For rpart, we pruned a deep tree to minimize validation error. We compared performance under fixed (\(k = 8\)) and optimized leaf counts, retaining the best model per case for each implementation (four total).
# Run (or load from cache) the repeated-partitioning experiment on Scenario 1;
# results are cached as a CSV so the expensive simulation only runs once.
if (!file.exists("preds/proxy_sims_results.csv")) {
# Load the experiment helpers
source("___evtree_experiment.R")
# Load the simulated prediction statistics
preds_sims_stats <- jsonlite::fromJSON("preds/preds_sims_stats.json")
# Run the experiment across all simulated sample sets
proxy_sims_results <- process_data_evtree_experiment(preds_sims_stats)
# Save results to CSV for future runs
write.csv(proxy_sims_results, "preds/proxy_sims_results.csv", row.names = FALSE)
cat("Results saved to preds/proxy_sims_results.csv")
} else {
proxy_sims_results <- read_csv("preds/proxy_sims_results.csv")
cat("Results read from preds/proxy_sims_results.csv")
}Results read from preds/proxy_sims_results.csv
The following table presents the results of our experiment. When the number of groups is correctly set to \(k = 8\), the \(8\times8\) accuracy and relaxed \(8\times8\) accuracy (which counts adjacent diagonals in the confusion matrix as correct) confirm that the method effectively groups individuals, with little difference between evtree and rpart. However, when \(k\) is unknown, \(k = 8\) was never seen as optimal. Validation metrics (based on estimated proxy vulnerability) favor larger \(k\), while oracle performance (based on theoretical proxy vulnerability) suggests better \(R^2\) at \(k = 8\). Whether \(k\) is known or not, partitioning identifies truly vulnerable individuals (top 12% by theoretical proxy vulnerability) with over 93% precision, fulfilling its primary objective. Since \(k\) is unknown in practice, strong regularization is essential to prevent excessive partitioning.
# Average each per-simulation metric over the N experiment runs, keeping both
# the mean and the standard deviation. The original version repeated
# `proxy_sims_results$col` inside `summarise()` ~120 times, which both
# bypasses the pipe (any upstream grouping/filtering would be silently
# ignored) and invites copy-paste errors; here a single name -> column map
# drives the computation instead.
#
# Output column names are preserved exactly: for each map entry `name`,
# `summary_results$name` is the mean and `summary_results$name_sd` the SD.
# Most names equal their source column; the exceptions are the k_*
# (num_leaf_*), oracle_r2_* (r2_oracle_*) and acc_test_* (accuracy_test_*)
# renamings.
metric_map <- c(
  # Metrics for "known k" (k = 8) evtree
  k_8 = "num_leaf_8", minbucket_8 = "minbucket_8", maxdepth_8 = "maxdepth_8",
  alpha_8 = "alpha_8", validation_mse_8 = "validation_mse_8",
  oracle_r2_8 = "r2_oracle_8", acc_test_8 = "accuracy_test_8",
  relaxed_acc_test_8 = "relaxed_accuracy_test_8",
  top_acc_8 = "top_acc_8", bottom_acc_8 = "bottom_acc_8",
  time_8 = "time_8",  # Time assumed same for any/known
  # Metrics for "known k" (k = 8) rpart
  k_r8 = "num_leaf_r8", minbucket_r8 = "minbucket_r8",
  maxdepth_r8 = "maxdepth_r8", cp_8 = "cp_8",
  validation_mse_r8 = "validation_mse_r8", oracle_r2_r8 = "r2_oracle_r8",
  acc_test_r8 = "accuracy_test_r8",
  relaxed_acc_test_r8 = "relaxed_accuracy_test_r8",
  top_acc_r8 = "top_acc_r8", bottom_acc_r8 = "bottom_acc_r8",
  time_r8 = "time_r8",  # Time assumed same for any/known
  # Metrics for "any k" rpart
  k_rany = "num_leaf_rany", minbucket_rany = "minbucket_rany",
  maxdepth_rany = "maxdepth_rany", cp_any = "cp_any",
  validation_mse_rany = "validation_mse_rany",
  oracle_r2_rany = "r2_oracle_rany",
  top_acc_rany = "top_acc_rany", bottom_acc_rany = "bottom_acc_rany",
  time_rany = "time_rany",  # Time is shared, no distinction
  # Metrics for "any k" evtree
  k_any = "num_leaf_any", minbucket_any = "minbucket_any",
  maxdepth_any = "maxdepth_any", alpha_any = "alpha_any",
  validation_mse_any = "validation_mse_any",
  oracle_r2_any = "r2_oracle_any",
  top_acc_any = "top_acc_any", bottom_acc_any = "bottom_acc_any",
  time_any = "time_any",  # Time is shared, no distinction
  # Detection metrics for the eo, e8, ro, r8 configurations
  recall_top_pct_eo = "recall_top_pct_eo", prec_top_pct_eo = "prec_top_pct_eo",
  acc_top_pct_eo = "acc_top_pct_eo", effpct_top_pct_eo = "effpct_top_pct_eo",
  recall_top_pct_e8 = "recall_top_pct_e8", prec_top_pct_e8 = "prec_top_pct_e8",
  acc_top_pct_e8 = "acc_top_pct_e8", effpct_top_pct_e8 = "effpct_top_pct_e8",
  recall_top_pct_ro = "recall_top_pct_ro", prec_top_pct_ro = "prec_top_pct_ro",
  acc_top_pct_ro = "acc_top_pct_ro", effpct_top_pct_ro = "effpct_top_pct_ro",
  recall_top_pct_r8 = "recall_top_pct_r8", prec_top_pct_r8 = "prec_top_pct_r8",
  acc_top_pct_r8 = "acc_top_pct_r8", effpct_top_pct_r8 = "effpct_top_pct_r8"
)

# One-row data frame: <name> = mean, <name>_sd = sd, interleaved per metric
# as in the hand-written version. NAs (e.g. failed fits) are dropped from
# both statistics via na.rm = TRUE.
summary_results <- as.data.frame(
  unlist(
    lapply(names(metric_map), function(nm) {
      x <- proxy_sims_results[[metric_map[[nm]]]]
      setNames(
        list(mean(x, na.rm = TRUE), sd(x, na.rm = TRUE)),
        c(nm, paste0(nm, "_sd"))
      )
    }),
    recursive = FALSE
  )
)
# Step 3: Transform results into a cleaner format.
# Built inside local() so the lookup helpers and key vectors do not leak
# into the global environment. Column order must stay exactly as below:
# the rounding step later addresses columns 2:9 by position.
summary_table <- local({
  # Fetch scalar statistics from summary_results by name; an NA key stands
  # for a metric that is not available for that configuration.
  pick <- function(keys) {
    vapply(
      keys,
      function(key) if (is.na(key)) NA_real_ else summary_results[[key]],
      numeric(1),
      USE.NAMES = FALSE
    )
  }
  # The same key list drives the SD column: the summarising step stores
  # every standard deviation under "<key>_sd".
  pick_sd <- function(keys) {
    pick(ifelse(is.na(keys), NA, paste0(keys, "_sd")))
  }

  # Per-configuration key lists, one entry per row of the Metric column.
  rpart_known <- c("k_r8", "minbucket_r8", "maxdepth_r8", "cp_8",
                   "validation_mse_r8", "oracle_r2_r8",
                   "acc_test_r8", "relaxed_acc_test_r8",
                   "top_acc_r8", "bottom_acc_r8",
                   "recall_top_pct_r8", "effpct_top_pct_r8", "time_r8")
  ev_known <- c("k_8", "minbucket_8", "maxdepth_8", "alpha_8",
                "validation_mse_8", "oracle_r2_8",
                "acc_test_8", "relaxed_acc_test_8",
                "top_acc_8", "bottom_acc_8",
                "recall_top_pct_e8", "effpct_top_pct_e8", "time_8")
  # Accuracy-against-true-groups metrics are undefined when k is tuned,
  # hence the two NA slots.
  rpart_any <- c("k_rany", "minbucket_rany", "maxdepth_rany", "cp_any",
                 "validation_mse_rany", "oracle_r2_rany", NA, NA,
                 "top_acc_rany", "bottom_acc_rany",
                 "recall_top_pct_ro", "effpct_top_pct_ro", "time_rany")
  ev_any <- c("k_any", "minbucket_any", "maxdepth_any", "alpha_any",
              "validation_mse_any", "oracle_r2_any", NA, NA,
              "top_acc_any", "bottom_acc_any",
              "recall_top_pct_eo", "effpct_top_pct_eo", "time_any")

  data.frame(
    Metric = c(
      "Number of Leaves (k)", "Min Split", "Max Depth", "Alpha",
      "Validation BIC", "Oracle R²", "Accuracy Test", "Relaxed Accuracy Test",
      "Top Accuracy", "Bottom Accuracy", "Detect recall",
      "Detect effective %", "Time"
    ),
    Rpart_Known_k_Mean = pick(rpart_known),
    Rpart_Known_k_SD = pick_sd(rpart_known),
    Ev_Known_k_Mean = pick(ev_known),
    Ev_Known_k_SD = pick_sd(ev_known),
    Rpart_Any_k_Mean = pick(rpart_any),
    Rpart_Any_k_SD = pick_sd(rpart_any),
    Ev_Any_k_Mean = pick(ev_any),
    Ev_Any_k_SD = pick_sd(ev_any),
    # Display grouping for the rendered table, ordered by level.
    group = factor(
      c("Hyperparam.", "Hyperparam.", "Hyperparam.", "Hyperparam.",
        "Efficiency",
        "Oracle perf.", "Oracle perf.", "Oracle perf.", "Oracle perf.",
        "Oracle perf.",
        "Efficiency", "Efficiency", "Efficiency"),
      levels = c("Hyperparam.", "Efficiency", "Oracle perf.")
    )
  )
})
library(knitr)
library(kableExtra)
## Round the eight numeric summary columns (positions 2:9) to 3 decimals
## for display; Metric (1) and group (10) are left untouched.
summary_table[, 2:9] <- round(summary_table[, 2:9], 3)

## Reorder for rendering: group column first, then Metric and the four
## Mean/SD pairs, with rows sorted by metric group (arrange() is a stable
## sort, so within-group row order is preserved).
table_to_g <- summary_table |>
  dplyr::select(
    group, Metric,
    Rpart_Known_k_Mean, Rpart_Known_k_SD,
    Ev_Known_k_Mean, Ev_Known_k_SD,
    Rpart_Any_k_Mean, Rpart_Any_k_SD,
    Ev_Any_k_Mean, Ev_Any_k_SD
  ) |>
  arrange(group)

## Rows that open a new visual group in the rendered table
## (a top border is drawn above rows 5 and 9).
group_lines <- c(5, 9)
# Render the summary as a kable with a three-level custom header.
# add_header_above() stacks bottom-up, so the calls go: Mean/SD labels,
# then tree type (greedy rpart vs optimal evtree), then known/unknown k.
# The chain's value is the chunk output, so it is deliberately not assigned.
kbl(
  dplyr::select(table_to_g, -group),
  col.names = NULL,
  caption = "Results of experimental testing. For each of the $N = 100$ samples, we obtained the best regression trees when forcing $k = 8$ or when $k$ was part of the tuned hyperparameters.",
  label = "experiment"
) |>
  add_header_above(c(
    "Metric" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1
  )) |>
  add_header_above(c(
    " " = 1,
    "Greedy tree \n (rpart)" = 2,
    "Optimal tree \n (evtree)" = 2,
    "Greedy tree \n (rpart)" = 2,
    "Optimal tree \n (evtree)" = 2
  )) |>
  add_header_above(
    c(" " = 1, "Known k" = 4, "Unknown k" = 4),
    escape = FALSE
  ) |>
  # Label row blocks using the per-group row counts (level order matches
  # the arranged data).
  group_rows(index = table(table_to_g$group)) |>
  # Horizontal rule above the first row of each non-initial group.
  row_spec(group_lines, extra_css = "border-top: 2px solid black;") |>
  kable_styling(full_width = FALSE, bootstrap_options = c("striped", "hover"))
Unknown k |
|||||||
|---|---|---|---|---|---|---|---|---|
Greedy tree (rpart) |
Optimal tree (evtree) |
Greedy tree (rpart) |
Optimal tree (evtree) |
|||||
Metric |
Mean |
SD |
Mean |
SD |
Mean |
SD |
Mean |
SD |
| Hyperparam. | ||||||||
| Number of Leaves (k) | 8.000 | 0.000 | 8.000 | 0.000 | 14.290 | 1.192 | 12.290 | 1.908 |
| Min Split | 0.037 | 0.010 | 0.037 | 0.010 | 0.032 | 0.005 | 0.031 | 0.005 |
| Max Depth | 3.150 | 0.359 | 3.150 | 0.359 | 3.990 | 0.100 | 3.970 | 0.171 |
| Alpha | 0.005 | 0.003 | 1.470 | 0.502 | 0.000 | 0.000 | 1.390 | 0.490 |
| Efficiency | ||||||||
| Validation BIC | 90.892 | 155.067 | Inf | NaN | 39.851 | 165.704 | 7.816 | 158.082 |
| Detect recall | 0.951 | 0.118 | 0.959 | 0.104 | 0.907 | 0.153 | 0.926 | 0.136 |
| Detect effective % | 0.171 | 0.068 | 0.183 | 0.072 | 0.152 | 0.042 | 0.161 | 0.052 |
| Time | 0.016 | 0.007 | 23.083 | 7.086 | 0.015 | 0.007 | 43.037 | 14.681 |
| Oracle perf. | ||||||||
| Oracle R² | 0.905 | 0.036 | 0.904 | 0.038 | 0.897 | 0.034 | 0.894 | 0.035 |
| Accuracy Test | 0.601 | 0.165 | 0.649 | 0.167 | NA | NA | NA | NA |
| Relaxed Accuracy Test | 0.971 | 0.051 | 0.976 | 0.047 | NA | NA | NA | NA |
| Top Accuracy | 0.941 | 0.076 | 0.932 | 0.079 | 0.957 | 0.064 | 0.951 | 0.067 |
| Bottom Accuracy | 0.942 | 0.077 | 0.939 | 0.076 | 0.967 | 0.056 | 0.964 | 0.051 |